import argparse
import csv
import os

import numpy as np

from mytbe import executor

def _conv_out_dim(size, pad_before, pad_after, kernel, stride, dilation=1):
    """Return the convolution output extent along one spatial axis.

    Standard convolution arithmetic:
    ``(size + pad_before + pad_after - dilation * (kernel - 1) - 1) // stride + 1``.
    Integer floor division avoids float rounding on large extents.
    """
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + pad_before + pad_after - effective_kernel) // stride + 1


def add_dynamic(file, soc_version):
    """Benchmark Conv2D for every case row of a CSV file.

    The first row of ``file`` is treated as a header and skipped. For each
    remaining row, a random float16 NHWC input and HWCN filter are generated,
    ``executor.acl_om("Conv2D", ...)`` is run on device 0, ``runtime[2]`` is
    printed, and the row is appended to ``result_<basename of file>`` (in the
    same directory as ``file``) with its last column replaced by
    ``runtime[2]["task_time"]``.

    Column layout, as inferred from how each value is used below:
    1 = input W, 2 = input H, 3 = input C, 4 = batch N, 5 = output C,
    6 = kernel W, 7 = kernel H, 8/9 = padding, 10/11 = strides.
    NOTE(review): the meaning of columns 8-11 is inferred only from where they
    land in the ``pads``/``strides`` attributes; confirm against whatever
    generates the CSV.

    Args:
        file: Path to the input CSV.
        soc_version: Target chip name forwarded to ``executor.acl_om``.
    """
    result_path = os.path.join(os.path.dirname(file),
                               'result_' + os.path.basename(file))

    with open(file, newline='') as src:
        rows = list(csv.reader(src))

    # newline='' is required by the csv module; the context manager makes sure
    # already-measured rows are flushed even if a later case raises.
    with open(result_path, 'a', newline='') as dst:
        csv_writer = csv.writer(dst)
        for row in rows[1:]:
            batch, in_h, in_w, in_c = int(row[4]), int(row[2]), int(row[1]), int(row[3])
            out_c = int(row[5])
            k_h, k_w = int(row[7]), int(row[6])

            x_shape = [batch, in_h, in_w, in_c]       # NHWC
            filter_shape = [k_h, k_w, in_c, out_c]    # HWCN
            x = np.random.random(x_shape).astype("float16")
            input_filter = np.random.random(filter_shape).astype("float16")

            inputs_info = [{"shape": x_shape, "dtype": "float16", "format": "NHWC"},
                           {"shape": filter_shape, "dtype": "float16", "format": "HWCN"}]

            strides = [1, int(row[10]), int(row[11]), 1]
            pads = [int(row[8]), int(row[8]), int(row[9]), int(row[9])]
            dilations = [1, 1, 1, 1]

            # Derive the output extent from the very attributes handed to the
            # operator, so the declared output shape matches what it computes.
            # Assumes NHWC-ordered strides/dilations ([1, h, w, 1]) and pads
            # ordered [top, bottom, left, right] -- TODO confirm against the
            # CANN Conv2D spec.
            out_h = _conv_out_dim(in_h, pads[0], pads[1], k_h, strides[1], dilations[1])
            out_w = _conv_out_dim(in_w, pads[2], pads[3], k_w, strides[2], dilations[2])
            out_shape = [batch, out_h, out_w, out_c]

            outputs_info = [{"shape": out_shape, "dtype": "float16", "format": "NHWC"}]
            real_out = [{"shape": list(out_shape), "dtype": "float16", "format": "NHWC"}]

            attr_dict = {"strides": strides, "pads": pads, "dilations": dilations}

            # executor.op_debug_level = 1

            _, runtime = executor.acl_om("Conv2D", inputs_info, [x, input_filter],
                                         outputs_info, attr_dict,
                                         soc_version=soc_version,
                                         device_id=0, real_out=real_out)
            print(runtime[2])
            row[-1] = runtime[2]["task_time"]
            csv_writer.writerow(row)


def main(argv=None):
    """Parse command-line options and run the Conv2D benchmark.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Conv2D performance test')
    parser.add_argument('--file', default='gemm_f16.csv', type=str,
                        help='CSV file listing the Conv2D cases to run')
    parser.add_argument('--soc_version', default='Ascend910', type=str, help='chip')
    args = parser.parse_args(argv)
    add_dynamic(args.file, args.soc_version)


if __name__ == "__main__":
    main()

